from urllib.parse import quote_plus
import os
import asyncio
from time import time
import argparse

import pandas as pd
from bs4 import BeautifulSoup
from pyppeteer import launch


BASE_URL = "https://www.google.com/search"


def update_eligible(qrys, curr_data):
    """Return the rows of ``qrys`` that have not been recorded yet.

    A query counts as recorded when its ``qry_id`` appears in the ``qry_id``
    column of ``curr_data``. A progress line is printed as a side effect.

    Args:
        qrys: DataFrame of candidate queries with a ``qry_id`` column.
        curr_data: DataFrame of already-recorded results with a ``qry_id``
            column.

    Returns:
        The subset of ``qrys`` (original index preserved) still to be crawled.
    """
    done_ids = curr_data.qry_id.values
    already_done = qrys.qry_id.isin(done_ids)
    remaining = qrys.loc[~already_done]
    print(
        f"{len(curr_data)} queries have been parsed. {len(remaining)} queries left to parse."
    )
    return remaining


def log(curr_data, new, fp_output):
    """Append successful results to the running table and save it as CSV.

    Args:
        curr_data: DataFrame of results accumulated so far.
        new: Iterable of ``(qry_id, google_entity, html_len)`` records.
            Records whose ``google_entity`` is missing (``None``/NaN) are
            dropped, so those queries are not marked as parsed.
        fp_output: Destination path; the file is overwritten with the full
            combined table (no index column).

    Returns:
        The combined DataFrame (``curr_data`` followed by the kept records).
    """
    batch = pd.DataFrame(new, columns=["qry_id", "google_entity", "html_len"])
    has_entity = batch["google_entity"].notna()
    parsed = batch.loc[has_entity, :]
    combined = pd.concat([curr_data, parsed])
    combined.to_csv(fp_output, index=False)
    return combined


def parse_serp(html):
    """Report whether a results page contains one of the expected panels.

    ***Note: these identifiers may have changed since original 2022 data collection***

    Two locations are checked, in order:
      1. a ``div.kp-wholepage-osrp`` nested inside ``div.XqFnDf``;
      2. a ``div.kp-wholepage`` nested inside ``div#rhs``.
    (Presumably these are Google's entity/knowledge panels -- confirm against
    a live page.)

    Args:
        html: Raw HTML of the search results page.

    Returns:
        ``True`` if either location matches, otherwise ``False``.
    """
    soup = BeautifulSoup(html, "lxml")
    top_bar = soup.find("div", {"class": "XqFnDf"})
    rhs = soup.find("div", {"id": "rhs"})

    # Panel shown at the top of the page.
    panel_in_top_bar = top_bar and top_bar.find("div", {"class": "kp-wholepage-osrp"})
    if panel_in_top_bar:
        return True

    # Panel shown in the right-hand column.
    panel_in_rhs = rhs and rhs.find("div", {"class": "kp-wholepage"})
    return bool(panel_in_rhs)


async def search_and_parse(browser, qry, qry_id, min_html_len=200_000, debug=None):
    """Search one query in a new browser page and parse the resulting HTML.

    Args:
        browser: A launched pyppeteer browser; a new page is opened on it and
            closed again before returning.
        qry: Query text; it is URL-encoded into the ``q`` parameter.
        qry_id: Identifier echoed back in the result tuple.
        min_html_len: Pages whose HTML is shorter than this are treated as
            failed (entity ``None``) and a screenshot is saved for inspection.
        debug: When true, always save a screenshot. ``None`` (the default)
            falls back to the module-level ``args.debug`` if the script set
            it, else ``False`` -- so the function also works when imported.

    Returns:
        ``(qry_id, entity, html_len)`` where ``entity`` is the boolean result
        of ``parse_serp``, or ``None`` when the page was too short or any
        error occurred (in the error case ``html_len`` is ``0``).
    """
    if debug is None:
        # Previously this read the global ``args`` directly; when the module
        # was imported that raised NameError, which was swallowed below and
        # made every query look like a failure.
        debug = bool(getattr(globals().get("args"), "debug", False))

    page = None
    try:
        page = await browser.newPage()
        await page.goto(f"{BASE_URL}?q={quote_plus(qry)}")
        html = await page.content()
        html_len = len(html)
        too_short = html_len < min_html_len

        if debug or too_short:
            await page.screenshot({"path": "debug_ner.png"})
        if too_short:
            return (qry_id, None, html_len)
        return (qry_id, parse_serp(html), html_len)
    except Exception as exc:
        # ``Exception`` rather than a bare ``except`` so task cancellation and
        # KeyboardInterrupt still propagate. Any other failure is reported to
        # the caller as an unsuccessful query.
        if debug:
            print(f"query {qry_id} failed: {exc!r}")
        return (qry_id, None, 0)
    finally:
        # Release the tab so pages do not accumulate on the shared browser.
        if page is not None:
            try:
                await page.close()
            except Exception:
                # Best-effort cleanup: the browser may already be gone.
                pass


async def main(args):
    """Run the crawl described by the parsed command-line arguments.

    In debug mode a single hard-coded query is searched and its result tuple
    printed. Otherwise queries are read from ``args.fp_input``, those already
    present in ``args.fp_output`` are skipped, and the rest are crawled in
    batches of ``args.threads`` concurrent pages. After every batch the
    successful results are written to ``args.fp_output``. Crawling stops when
    no queries remain or when a batch's success rate drops below
    ``args.success_rate``.

    Args:
        args: Namespace with ``debug``, ``fp_input``, ``fp_output``,
            ``threads`` and ``success_rate`` attributes.

    Raises:
        ValueError: If ``args.threads`` is smaller than 1.
        Exception: Errors from reading/writing the CSV files or launching the
            browser propagate to the caller (the browser is closed first).
    """
    if args.debug:
        browser = await launch()
        try:
            results = await search_and_parse(browser, "steve carell", 0)
            print(results)
        finally:
            await browser.close()
        return

    if args.threads < 1:
        # An empty batch would otherwise divide by zero in the success rate.
        raise ValueError("threads must be at least 1")

    qrys = pd.read_csv(args.fp_input)
    if os.path.exists(args.fp_output):
        curr_data = pd.read_csv(args.fp_output)
        qrys = update_eligible(qrys, curr_data)
    else:
        # Same columns that ``log`` writes, so the table shape is consistent
        # from the first batch on.
        curr_data = pd.DataFrame([], columns=["qry_id", "google_entity", "html_len"])

    # NOTE: failed queries are not logged, so they stay at the head of
    # ``qrys`` and are retried in the next batch; the success-rate threshold
    # is what eventually stops a crawl that keeps failing.
    success_rate = 1
    while len(qrys) and success_rate >= args.success_rate:
        st = time()
        browser = await launch()
        try:
            tasks = [
                asyncio.ensure_future(search_and_parse(browser, row.qry, row.qry_id))
                for _, row in qrys.head(args.threads).iterrows()
            ]
            results = await asyncio.gather(*tasks)
            print(f"crawling {len(results)} queries took {time()-st}s")
        finally:
            # Always release the browser, including on errors or cancellation;
            # the error itself is left to propagate instead of being hidden.
            await browser.close()

        n_failed = sum(res[1] is None for res in results)
        success_rate = 1 - n_failed / len(results)
        print(f"success rate: {success_rate}")

        curr_data = log(curr_data, results, args.fp_output)
        qrys = update_eligible(qrys, curr_data)

    if success_rate < args.success_rate:
        print(f"success rate fell below {args.success_rate}")
    else:
        print("crawled all queries")


if __name__ == "__main__":
    # Command-line entry point: parse the options, validate them, and run the
    # asynchronous crawl. ``args`` is intentionally left at module level
    # because other code in this file reads it as a global.
    parser = argparse.ArgumentParser(
        description="Parse Google SERPs to identify named entities missed by spacy"
    )
    parser.add_argument(
        "-i",
        "--fp_input",
        type=str,
        help="Input file.",
    )
    parser.add_argument(
        "-s",
        "--fp_output",
        type=str,
        help="Output file.",
    )
    parser.add_argument(
        "-t",
        "--threads",
        default=50,
        type=int,
        help="Number of queries to crawl in parallel. Default=50",
    )
    parser.add_argument(
        "-r",
        "--success_rate",
        default=1,
        type=float,
        # Help text previously claimed 0.5, which did not match the default.
        help="Kill crawl if success rate falls below this number. Default=1",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Whether to run in debug mode (parse 1 query). Default=False",
    )
    args = parser.parse_args()
    # Debug mode uses a built-in query and needs no files; a normal crawl
    # needs both paths, so fail with a usage message rather than deep inside
    # the CSV reader.
    if not args.debug and not (args.fp_input and args.fp_output):
        parser.error("--fp_input and --fp_output are required unless --debug is set")
    asyncio.run(main(args))
